{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a2f88613",
   "metadata": {},
   "source": [
    "# Lesson 5: Leveraging Assistants API for SQL Databases"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35d64829",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2565494f-532f-4fcb-99b0-e50fe1cb6eef",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "from openai import AzureOpenAI\n",
    "import json\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e85ca259",
   "metadata": {},
   "source": [
    "## Import the helper function\n",
    "\n",
    "To access the ``Helper.py`` file, please go to the ``File`` menu and select ``Open...``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270ea973-e750-48d2-ad96-16c1a0406232",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "import Helper\n",
    "from Helper import get_positive_cases_for_state_on_date\n",
    "from Helper import get_hospitalized_increase_for_state_on_date"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b16b04b",
   "metadata": {},
   "source": [
    "## Launch the Assistant API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d24b35d",
   "metadata": {},
   "source": [
    "**Note**: The pre-configured cloud resource grants you access to the Azure OpenAI GPT model. The key and endpoint provided below are intended for teaching purposes only. Your notebook environment is already set up with the necessary keys, which may differ from those used by the instructor during the filming."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3842d0a1-e2af-4fba-8131-ec5d8ba92927",
   "metadata": {
    "height": 285
   },
   "outputs": [],
   "source": [
    "client = AzureOpenAI(\n",
    "    api_key=os.getenv(\"AZURE_OPENAI_KEY\"),\n",
    "    api_version=\"2024-02-15-preview\",\n",
    "    azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
    "    )\n",
    "\n",
    "# I) Create assistant\n",
    "assistant = client.beta.assistants.create(\n",
    "  instructions=\"\"\"You are an assistant answering questions \n",
    "                  about a Covid dataset.\"\"\",\n",
    "  model=\"gpt-4-1106\", \n",
    "  tools=Helper.tools_sql)\n",
    "\n",
    "# II) Create thread\n",
    "thread = client.beta.threads.create()\n",
    "print(thread)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16c4150e-dac3-4f25-9a02-4cbffad8d35b",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "# III) Add message\n",
    "message = client.beta.threads.messages.create(\n",
    "    thread_id=thread.id,\n",
    "    role=\"user\",\n",
    "    content=\"\"\"how many hospitalized people we had in Alaska\n",
    "               the 2021-03-05?\"\"\"\n",
    ")\n",
    "print(message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f89fc9b-cad5-452f-9d9f-d6b48f722fa9",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "messages = client.beta.threads.messages.list(\n",
    "  thread_id=thread.id\n",
    ")\n",
    "\n",
    "print(messages.model_dump_json(indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e402a07d-fd1c-454e-b2fb-b2c71fb76a6a",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# IV) Run assistant on thread\n",
    "\n",
    "run = client.beta.threads.runs.create(\n",
    "  thread_id=thread.id,\n",
    "  assistant_id=assistant.id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "025035a0",
   "metadata": {},
   "source": [
    "## Leverage the function calling with Assistants API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf1b35b9-f4fa-4b3f-ac38-54366c3230fd",
   "metadata": {
    "height": 914
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from IPython.display import clear_output\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "status = run.status\n",
    "\n",
    "while status not in [\"completed\", \"cancelled\", \"expired\", \"failed\"]:\n",
    "    time.sleep(5)\n",
    "    run = client.beta.threads.runs.retrieve(\n",
    "        thread_id=thread.id,run_id=run.id\n",
    "    )\n",
    "    print(\"Elapsed time: {} minutes {} seconds\".format(\n",
    "        int((time.time() - start_time) // 60),\n",
    "        int((time.time() - start_time) % 60))\n",
    "         )\n",
    "    status = run.status\n",
    "    print(f'Status: {status}')\n",
    "    if (status==\"requires_action\"):\n",
    "        available_functions = {\n",
    "            \"get_positive_cases_for_state_on_date\": get_positive_cases_for_state_on_date,\n",
    "            \"get_hospitalized_increase_for_state_on_date\":get_hospitalized_increase_for_state_on_date\n",
    "        }\n",
    "\n",
    "        tool_outputs = []\n",
    "        for tool_call in run.required_action.submit_tool_outputs.tool_calls:\n",
    "            function_name = tool_call.function.name\n",
    "            function_to_call = available_functions[function_name]\n",
    "            function_args = json.loads(tool_call.function.arguments)\n",
    "            function_response = function_to_call(\n",
    "                state_abbr=function_args.get(\"state_abbr\"),\n",
    "                specific_date=function_args.get(\"specific_date\"),\n",
    "            )\n",
    "            print(function_response)\n",
    "            print(tool_call.id)\n",
    "            tool_outputs.append(\n",
    "                { \"tool_call_id\": tool_call.id,\n",
    "                 \"output\": str(function_response)\n",
    "                }\n",
    "            )\n",
    "\n",
    "        run = client.beta.threads.runs.submit_tool_outputs(\n",
    "          thread_id=thread.id,\n",
    "          run_id=run.id,\n",
    "          tool_outputs = tool_outputs\n",
    "        )\n",
    "\n",
    "\n",
    "messages = client.beta.threads.messages.list(\n",
    "  thread_id=thread.id\n",
    ")\n",
    "\n",
    "print(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bc1379-0ee3-49c3-a954-b026da3aa672",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "print(messages.model_dump_json(indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65faedb6",
   "metadata": {},
   "source": [
    "## Add the code interpreter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6768848b-30f1-4cf5-9c60-a84ec5abc97f",
   "metadata": {
    "height": 404
   },
   "outputs": [],
   "source": [
    "file = client.files.create(\n",
    "  file=open(\"./data/all-states-history.csv\", \"rb\"),\n",
    "  purpose='assistants'\n",
    ")\n",
    "assistant = client.beta.assistants.create(\n",
    "  instructions=\"\"\"You are an assitant answering questions about\n",
    "                  a Covid dataset.\"\"\",\n",
    "  model=\"gpt-4-1106\", \n",
    "  tools=[{\"type\": \"code_interpreter\"}],\n",
    "  file_ids=[file.id])\n",
    "thread = client.beta.threads.create()\n",
    "print(thread)\n",
    "message = client.beta.threads.messages.create(\n",
    "    thread_id=thread.id,\n",
    "    role=\"user\",\n",
    "    content=\"\"\"how many hospitalized people we had in Alaska\n",
    "               the 2021-03-05?\"\"\"\n",
    ")\n",
    "print(message)\n",
    "run = client.beta.threads.runs.create(\n",
    "  thread_id=thread.id,\n",
    "  assistant_id=assistant.id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d278df-c8c9-4904-8ca5-4767a718d6bd",
   "metadata": {
    "height": 387
   },
   "outputs": [],
   "source": [
    "status = run.status\n",
    "start_time = time.time()\n",
    "while status not in [\"completed\", \"cancelled\", \"expired\", \"failed\"]:\n",
    "    time.sleep(5)\n",
    "    run = client.beta.threads.runs.retrieve(\n",
    "        thread_id=thread.id,\n",
    "        run_id=run.id\n",
    "    )\n",
    "    print(\"Elapsed time: {} minutes {} seconds\".format(\n",
    "        int((time.time() - start_time) // 60),\n",
    "        int((time.time() - start_time) % 60))\n",
    "         )\n",
    "    status = run.status\n",
    "    print(f'Status: {status}')\n",
    "    clear_output(wait=True)\n",
    "\n",
    "\n",
    "messages = client.beta.threads.messages.list(\n",
    "  thread_id=thread.id\n",
    ")\n",
    "\n",
    "print(messages.model_dump_json(indent=2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
